#pragma once
#include "dimension.h"
#include "esmd_types.h"
#include <sys/cdefs.h>

/* Scratch-buffer aliases: these object-like macros rename tokens so that
 * several algorithms can reuse existing celldata_t arrays (shake_tmp, v) and
 * counters (nguest) as temporaries instead of declaring dedicated fields.
 * NOTE(review): presumably the phases using the same underlying member never
 * need it at the same time -- confirm before adding a new alias.
 * Because they are plain macros, every token spelled e.g. `cg_h` in any file
 * that includes this header is rewritten. */
#define shake_xuc shake_tmp
#define shake_vp shake_tmp
/* Use v and shake_tmp temporary for CG (CG: presumably the conjugate-gradient
 * solver -- TODO confirm) */
#define cg_h v
#define cg_g shake_tmp
//#define cg_x0 v
/* nguest can be overwritten when cell exporting; nbytes_export names that
 * reuse of the celldata_t::nguest field */
#define nbytes_export nguest
/* Number of neighbour-cell slots for tables indexed by hdcell(): a cube of
 * (2 * MAX_NN + 1)^3 cell offsets plus 2 extra entries (presumably room for an
 * end sentinel in the first_*_cell arrays -- TODO confirm).
 * MAX_NN is parenthesized so an expression-valued definition expands
 * correctly. */
#define MAX_NEIGH (((MAX_NN) * 2 + 1) * ((MAX_NN) * 2 + 1) * ((MAX_NN) * 2 + 1) + 2)
/* Overflow policy selector. NOTE(review): presumably the value passed as the
 * integer `ovf` argument of build_cells(); the behaviour of each policy is
 * implemented outside this header. */
enum overflow {
  OVF_WRAP = 0,
  OVF_SKIP = 1,
  OVF_KEEP = 2
};
/* Record of three real values named rsq (presumably squared constraint
 * lengths -- TODO confirm). */
struct shake_rec {
  real rsq[3];
};

/* One constraint entry (the name suggests SHAKE): a type code, three atom
 * index slots and three rsq values. Builds with WITH_LINCS carry two extra
 * three-element arrays, Sdiag and r0. */
struct shake_t {
  int type;
  int idx[3];
  real rsq[3];
#ifdef WITH_LINCS
  real Sdiag[3];
  real r0[3];
#endif
};
typedef struct shake_t shake_t;

/* Record of three r0 values (the name suggests LINCS reference lengths --
 * TODO confirm). */
struct lincs_rec {
  real r0[3];
};
#include "rigid_rec.hpp"
/* rigid_tmp is one more alias of the celldata_t::shake_tmp scratch array, so
 * code written against rigid_tmp works in the same per-cell storage as
 * shake_tmp and its other aliases. */
#define rigid_tmp shake_tmp
/* Forward declarations of the hash-table template and key type referenced
 * through pointers by cellgrid_t (tag_map, type_map). The complete
 * definitions are not needed to declare those pointer members.
 * No `;` after the namespace's closing brace: that would be a stray empty
 * declaration. */
namespace esmd {
  template <typename Key, typename ValType>
  class htab;
}
class tag_key;
/* Bonded-term list: atom index tuples for bonds (2 atoms), angles (3),
 * torsions (4) and impropers (4), per-term parameter arrays, and one counter
 * per list (presumably the number of entries in use -- TODO confirm).
 * Members are declared one per line in the original order, so the layout is
 * unchanged. */
struct top_list {
  /* atom index tuples */
  int bonds[MAX_BOND_CALC_CELL][2];
  int angles[MAX_ANGLE_CALC_CELL][3];
  int toris[MAX_TORI_CALC_CELL][4];
  int imprs[MAX_IMPR_CALC_CELL][4];
  /* bond parameters */
  real r0[MAX_BOND_CALC_CELL];
  real kr[MAX_BOND_CALC_CELL];
  /* angle parameters; r0ub/kub presumably Urey-Bradley terms */
  real theta0[MAX_ANGLE_CALC_CELL];
  real ktheta[MAX_ANGLE_CALC_CELL];
  real r0ub[MAX_ANGLE_CALC_CELL];
  real kub[MAX_ANGLE_CALC_CELL];
  /* torsion parameters */
  real phi0[MAX_TORI_CALC_CELL];
  real kphi[MAX_TORI_CALC_CELL];
  real nperiod[MAX_TORI_CALC_CELL];
  /* improper parameters */
  real imphi0[MAX_IMPR_CALC_CELL];
  real kimphi[MAX_IMPR_CALC_CELL];
  /* per-list counters */
  int nbond;
  int nangle;
  int ntori;
  int nimpr;
};
typedef struct top_list top_list_t;

/* Storage for one cell of the grid. Arrays sized CELL_CAP hold data only for
 * the cell's own atoms. Arrays sized CELL_CAP + MAX_CELL_GUEST also have room
 * for guest atoms (count in nguest, ids in guest_id -- TODO confirm the exact
 * ordering). */
struct celldata_t {
  /*static data of nonb input: per-atom charge*/
  real q[CELL_CAP];

  /*dynamic data of nonb input, integrator input*/
  vec<real> x[CELL_CAP + MAX_CELL_GUEST];
  vec<real> xinit[CELL_CAP];
  /*bonded input*/
  long tag[CELL_CAP + MAX_CELL_GUEST];
  long mol[CELL_CAP];
  /*type of atom*/
  int t[CELL_CAP + MAX_CELL_GUEST];

  rigid_rec rigid[CELL_CAP];

  /* Id pairs of excluded atoms and of scaled (1-4) atoms, recorded as an id
   * list for efficiency during computation. Each list is sized by its own
   * per-cell limit: MAX_EXCL_CELL for excl_id, MAX_SCAL_CELL for scal_id. */
  int excl_id[MAX_EXCL_CELL][2];
  int scal_id[MAX_SCAL_CELL][2];
  /* One entry per neighbour-cell offset, indexed by hdcell(). Presumably
   * these are the first index into excl_id / scal_id / guest_id for each
   * neighbour cell -- TODO confirm. */
  int first_excl_cell[MAX_NEIGH];
  int first_scal_cell[MAX_NEIGH];
  int first_guest_cell[MAX_NEIGH];

  /*a->{bcde}, impr: a(cde), (123)*/

  /*data for integrator input*/
  vec<real> f[CELL_CAP + MAX_CELL_GUEST];
  /* scratch array; several macros alias it (e.g. cg_g, rigid_tmp) */
  vec<real> shake_tmp[CELL_CAP + MAX_CELL_GUEST];
  vec<real> v[CELL_CAP + MAX_CELL_GUEST];
  real mass[CELL_CAP], rmass[CELL_CAP + MAX_CELL_GUEST];
  int guest_id[MAX_CELL_GUEST];

  /*lower left corner of cell, x<-coord of atom - basis*/
  vec<real> basis; //, len;
  int natom;
  #ifdef __sw__
  int first_frep;
  long pe_mask;
  long rep_init_mask;
  #endif
  /* nguest may be overwritten while cells are exported (nbytes_export alias) */
  int nexport, nguest;
  #ifdef __sw__
  //int nrguest;
  #endif
  //int nbonded_export, nchain2_export, nexcl_export, nscal_export, nimpr_export __attribute_deprecated__;

};
/* Selectors for individual per-cell fields. NOTE(review): each enumerator
 * presumably names the celldata_t member with the matching name (CF_N ->
 * natom, CF_X -> x, CF_TAG -> tag, CF_TYPE -> t, ...) -- confirm against the
 * users of this enum. */
enum cell_fields {
  CF_N = 0,
  CF_X = 1,
  CF_V = 2,
  CF_F = 3,
  CF_Q = 4,
  CF_TAG = 5,
  CF_TYPE = 6,
  CF_MASS = 7,
  CF_RIGID = 8
};
// const int MAX_NNDIST = 3;
// const int MAX_NN_CELL = 32 * CELL_CAP;
// struct nnlist {
//   int last_nn[CELL_CAP][MAX_NNDIST + 1];
//   long nns[MAX_NN_CELL];
// };
struct topology_grids;

/* The cell grid of the current process. `cells` is one flat array addressed
 * row-major with strides nall.y * nall.z, nall.z and 1. Every coordinate is
 * shifted by nn before linearization, so cells can be addressed with
 * coordinates down to -nn. */
struct cellgrid_t {
  /* flat storage for all cells */
  celldata_t *cells;
  // nnlist *nns;
  /* how many cells are searched around a cell in each direction */
  int nn;
  /* local / all cell counts of the current process along each axis */
  vec<int> nlocal;
  vec<int> nall;
  /* lower/upper bound of the current process */
  box<int> dim;
  /* skin, cell edge length and its reciprocal along each axis */
  vec<real> skin;
  vec<real> len;
  vec<real> rlen;
  /* lower/upper bound of the global and the local box */
  box<real> gbox;
  box<real> lbox;
  /* cutoff radius */
  real rcut;
  long natoms;
  int updated;
  void *arch_data;
  topology_grids *topo;
  esmd::htab<tag_key, vec<int>> *tag_map;
  esmd::htab<tag_key, int> *type_map;

  /* The cell at (x, y, z); each coordinate plus nn must lie inside the
   * corresponding nall extent. */
  __always_inline celldata_t &cell_at(int x, int y, int z){
    const int ix = x + nn;
    const int iy = y + nn;
    const int iz = z + nn;
    return cells[(ix * nall.y + iy) * nall.z + iz];
  }
};

//#include "bonded.h"
/* hdcell(nn, dx, dy, dz): dense ("condensed") index of a neighbour-cell offset
 * with every component in [-nn, nn], in the range [0, (2*nn+1)^3). It is used
 * together with the first_excl_cell / first_scal_cell arrays.
 *
 * enc_dcell(nn, dx, dy, dz): sparse but easy-to-decode encoding that packs
 * dx+nn, dy+nn and dz+nn into 4-bit fields at bits 8-11, 4-7 and 0-3;
 * get_dcell() reverses it. The fields are not masked here and the decoder
 * reads 4 bits each, so the encoding is only valid while 2*nn <= 15
 * (nn <= 7).
 *
 * The *_internal forms parenthesize every parameter so they are safe to use
 * directly with compound argument expressions (e.g. `c ? a : b`). */
#define hdcell_internal(nn, dx, dy, dz) \
  ((((dx) + (nn)) * ((nn) * 2 + 1) + ((dy) + (nn))) * ((nn) * 2 + 1) + ((dz) + (nn)))
#define hdcell(nn, dx, dy, dz) hdcell_internal((nn), (dx), (dy), (dz))
#define enc_dcell_internal(nn, dx, dy, dz) \
  (((dx) + (nn)) << 8 | ((dy) + (nn)) << 4 | ((dz) + (nn)))
#define enc_dcell(nn, dx, dy, dz) enc_dcell_internal((nn), (dx), (dy), (dz))

/* Cell-iteration helpers. Each macro opens nested `for` loops that declare
 * the index variables named by its arguments, followed by a one-shot loop that
 * binds `cell` (a celldata_t *) to get_cell_xyz(grid, ...) for the statement
 * written after the macro. Consequently a `break` in that statement leaves only
 * the one-shot loop, and iteration continues with the next index.
 * All macros use nullptr as the one-shot terminator, so no header providing
 * NULL is required. */

/*loop over all cells in current process (including ghosts):
 *grid is cellgrid, i, j, k is used for cell index in each axis, each running
 *over [grid->dim.lo, grid->dim.hi),
 *cell is the variable name for cell pointer
 */
#define FOREACH_CELL(grid, i, j, k, cell)                       \
  for (int i = (grid)->dim.lo.x; i < (grid)->dim.hi.x; i++)     \
    for (int j = (grid)->dim.lo.y; j < (grid)->dim.hi.y; j++)   \
      for (int k = (grid)->dim.lo.z; k < (grid)->dim.hi.z; k++) \
        for (celldata_t *cell = get_cell_xyz((grid), i, j, k); cell != nullptr; cell = nullptr)

/*loop over local cells in current process (without ghosts): i, j, k in [0, nlocal)*/
#define FOREACH_LOCAL_CELL(grid, i, j, k, cell)  \
  for (int i = 0; i < (grid)->nlocal.x; i++)     \
    for (int j = 0; j < (grid)->nlocal.y; j++)   \
      for (int k = 0; k < (grid)->nlocal.z; k++) \
        for (celldata_t *cell = get_cell_xyz((grid), i, j, k); cell != nullptr; cell = nullptr)

/*loop over neighbor cells of ijk: every offset in [-nn, nn]^3, including (0, 0, 0)*/
#define FOREACH_NEIGHBOR(grid, i, j, k, dx, dy, dz, cell) \
  for (int dx = -(grid)->nn; dx <= (grid)->nn; dx++)      \
    for (int dy = -(grid)->nn; dy <= (grid)->nn; dy++)    \
      for (int dz = -(grid)->nn; dz <= (grid)->nn; dz++)  \
        for (celldata_t *cell = get_cell_xyz((grid), i + dx, j + dy, k + dz); cell != nullptr; cell = nullptr)

/*loop over neighbor cells of ijk with distance 1 (offsets in [-1, 1]^3, including
 *(0, 0, 0)); this is used in finding bonds and cell import*/
#define FOREACH_NEARNEIGHBOR(grid, i, j, k, dx, dy, dz, cell) \
  for (int dx = -1; dx <= 1; dx++)                            \
    for (int dy = -1; dy <= 1; dy++)                          \
      for (int dz = -1; dz <= 1; dz++)                        \
        for (celldata_t *cell = get_cell_xyz((grid), i + dx, j + dy, k + dz); cell != nullptr; cell = nullptr)

/*loop over half of the neighbors of ijk: the offsets (dx, dy, dz) in [-nn, nn]^3
 *that are lexicographically <= (0, 0, 0), i.e. dx < 0; or dx == 0 and dy < 0;
 *or dx == dy == 0 and dz <= 0. The cell itself (0, 0, 0) is included.*/
#define FOREACH_NEIGHBOR_HALF_SHELL(grid, i, j, k, dx, dy, dz, cell)                \
  for (int dx = -(grid)->nn; dx <= 0; dx++)                                         \
    for (int dy = -(grid)->nn; dy <= (dx == 0 ? 0 : (grid)->nn); dy++)              \
      for (int dz = -(grid)->nn; dz <= (dx == 0 && dy == 0 ? 0 : (grid)->nn); dz++) \
        for (celldata_t *cell = get_cell_xyz((grid), i + dx, j + dy, k + dz); cell != nullptr; cell = nullptr)
/*get cell with coord x, y, z.
 *Each coordinate is shifted by grid->nn (so ghost cells can be addressed with
 *coordinates down to -nn) and linearized row-major with strides
 *nall.y * nall.z, nall.z and 1, the same layout as cellgrid_t::cell_at.
 *The complete integer offset is computed before it is added to grid->cells.
 *Adding the terms to the pointer one at a time could momentarily form an
 *address before cells[0] when z < 0, which is undefined behaviour even if the
 *final pointer is in range.*/
static inline celldata_t *get_cell_xyz(cellgrid_t *grid, int x, int y, int z) {
  const int offset = ((x + grid->nn) * grid->nall.y + (y + grid->nn)) * grid->nall.z + (z + grid->nn);
  return grid->cells + offset;
}
template<bool HasGhost>
static inline int get_offset_xyz(cellgrid_t *grid, int x, int y, int z) {
  if (HasGhost)
    return ((x + grid->nn) * grid->nall.y + (y + grid->nn)) * grid->nall.z + z + grid->nn;
  else
    return (x * grid->nlocal.y + y) * grid->nlocal.z + z;
}
template<bool HasGhost>
static inline int get_offset_xyz(const cellgrid_t &grid, int x, int y, int z) {
  if (HasGhost)
    return ((x + grid.nn) * grid.nall.y + (y + grid.nn)) * grid.nall.z + z + grid.nn;
  else
    return (x * grid.nlocal.y + y) * grid.nlocal.z + z;
}
/*reverse of enc_dcell: decode the three 4-bit fields of dcell (bits 8-11,
 *4-7, 0-3) into an offset (dx, dy, dz), each in [-nn, 15 - nn], and return
 *the cell at that offset from `cell` using the strides of grid->cells
 *(nall.y * nall.z, nall.z, 1).
 *The offset is summed as an integer before it is applied to the pointer, so
 *no out-of-range intermediate pointer is formed.*/
static inline celldata_t *get_dcell(cellgrid_t *grid, celldata_t *cell, int dcell) {
  const int dx = ((dcell >> 8) & 15) - grid->nn;
  const int dy = ((dcell >> 4) & 15) - grid->nn;
  const int dz = (dcell & 15) - grid->nn;
  return cell + ((dx * grid->nall.y + dy) * grid->nall.z + dz);
}
struct bond_graph_t;
struct impr_index_t;

//function signatures
/* Build `grid` for `natoms` atoms given by positions x, charges q, types t and
 * masses mass, with neighbour depth nn, cutoff rcut and skin, inside the
 * global box gbox and local box lbox. Bonded topology is passed as graph and
 * imidx. excls / scals / chain2 are per-atom fixed-size tag lists.
 * NOTE(review): parameter meanings are read from their names, and `ovf`
 * presumably takes an `enum overflow` value (OVF_*). Confirm against the
 * definition. */
void build_cells(cellgrid_t *grid, int nn, real rcut, real skin, box<real> * gbox, box<real> * lbox,
                 int natoms, double (*x)[3], real *q, int *t, real *mass,
                 bond_graph_t *graph,
                 impr_index_t *imidx,
                 long (*excls)[MAX_EXCLS_ATOM], long (*scals)[MAX_SCALS_ATOM], long (*chain2)[MAX_CHAIN2_ATOM][2], int ovf);
/* Like build_cells, but without the excls / scals / chain2 lists. The
 * overflow-policy parameter is named `ovf` as in build_cells; naming it
 * `overflow` would hide the `enum overflow` type inside this declaration. */
void build_cells_bondonly(cellgrid_t *grid, int nn, real rcut, real skin, box<real> * gbox, box<real> * lbox,
                 int natoms, double (*x)[3], real *q, int *t, real *mass,
                 bond_graph_t *graph,
                 impr_index_t *imidx,
                 int ovf);
int cell_check(cellgrid_t *grid);
void cell_export(cellgrid_t *grid);
void cell_import(cellgrid_t *grid);
void cell_sort(cellgrid_t *grid);
void cell_bfs_nn(cellgrid_t *grid, int nnear);
//end function signatures

//#include "listed.hpp"
